{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZjRn1KazJQqn"
      },
      "source": [
        "Copyright 2020 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "wHZQpP65JGvj"
      },
      "outputs": [],
      "source": [
        "#@title License\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y0iMnWR_nk9E"
      },
      "source": [
        "# RepNet\n",
        "\n",
        "This colab contains a pre-trained [RepNet](https://sites.google.com/view/repnet) model. It can be used to predict the rate at which repetitions are happening in a video in a class-agnostic manner. These estimates can be used to count the number of repetitions in videos.\n",
        "\n",
        "This model is able to count repetitions in many domains, e.g. counting the number of reps while exercising or measuring the rate of biological events such as heartbeats.\n",
        "\n",
        "Ensure you are running the Colab with a GPU runtime.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P3i6Lyvg0VfT"
      },
      "source": [
        "# Setup\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "76L5XFonl_Bw"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "\n",
        "import base64\n",
        "import io\n",
        "import os\n",
        "import time\n",
        "\n",
        "import cv2\n",
        "\n",
        "from IPython.display import display\n",
        "from IPython.display import HTML\n",
        "from IPython.display import Javascript\n",
        "\n",
        "import matplotlib\n",
        "from matplotlib.animation import FuncAnimation\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "from scipy.signal import medfilt\n",
        "\n",
        "import tensorflow.compat.v2 as tf\n",
        "\n",
        "! pip install youtube_dl\n",
        "import youtube_dl\n",
        "\n",
        "from google.colab import drive\n",
        "from google.colab import output\n",
        "from google.colab.output import eval_js\n",
        "\n",
        "# Model definition\n",
        "layers = tf.keras.layers\n",
        "regularizers = tf.keras.regularizers\n",
        "\n",
        "\n",
        "class ResnetPeriodEstimator(tf.keras.models.Model):\n",
        "  \"\"\"RepNet model.\n",
        "\n",
        "  Pipeline, as wired up in `call`: a truncated ResNet50V2 embeds every frame,\n",
        "  a dilated 3D convolution mixes information across time, a temperature-scaled\n",
        "  self-similarity matrix is computed over the per-frame embeddings, and two\n",
        "  independent transformer + fully connected heads read that matrix to predict\n",
        "  per-frame period scores and per-frame within-period scores.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(\n",
        "      self,\n",
        "      num_frames=64,\n",
        "      image_size=112,\n",
        "      base_model_layer_name='conv4_block3_out',\n",
        "      temperature=13.544,\n",
        "      dropout_rate=0.25,\n",
        "      l2_reg_weight=1e-6,\n",
        "      temporal_conv_channels=512,\n",
        "      temporal_conv_kernel_size=3,\n",
        "      temporal_conv_dilation_rate=3,\n",
        "      conv_channels=32,\n",
        "      conv_kernel_size=3,\n",
        "      transformer_layers_config=((512, 4, 512),),\n",
        "      transformer_dropout_rate=0.0,\n",
        "      transformer_reorder_ln=True,\n",
        "      period_fc_channels=(512, 512),\n",
        "      within_period_fc_channels=(512, 512)):\n",
        "    \"\"\"Creates all sub-layers and variables of the model.\n",
        "\n",
        "    Args:\n",
        "      num_frames: Number of frames per input clip. Sets the length of the\n",
        "        positional encodings and the number of period classes\n",
        "        (num_frames // 2).\n",
        "      image_size: Height and width that frames are resized/reshaped to.\n",
        "      base_model_layer_name: Name of the ResNet50V2 layer whose output is used\n",
        "        as the per-frame feature map.\n",
        "      temperature: Divisor applied to the self-similarity matrix before its\n",
        "        softmax.\n",
        "      dropout_rate: Rate of the dropout applied before every fully connected\n",
        "        layer of both prediction heads.\n",
        "      l2_reg_weight: L2 kernel regularization weight for the temporal conv,\n",
        "        input projections and fully connected layers.\n",
        "      temporal_conv_channels: Number of filters of the temporal 3D conv.\n",
        "      temporal_conv_kernel_size: Kernel size of the temporal 3D conv.\n",
        "      temporal_conv_dilation_rate: Dilation of the temporal 3D conv, applied\n",
        "        along the time axis only.\n",
        "      conv_channels: Number of filters of the 2D conv applied to the\n",
        "        self-similarity matrix.\n",
        "      conv_kernel_size: Kernel size of that 2D conv.\n",
        "      transformer_layers_config: Sequence of (channels, heads, bottleneck\n",
        "        channels) tuples, one per transformer layer. The same config is used\n",
        "        for both heads.\n",
        "      transformer_dropout_rate: Dropout rate inside the transformer layers.\n",
        "      transformer_reorder_ln: Passed to TransformerLayer as reorder_ln.\n",
        "      period_fc_channels: Hidden sizes of the period head's fully connected\n",
        "        layers.\n",
        "      within_period_fc_channels: Hidden sizes of the within-period head's\n",
        "        fully connected layers.\n",
        "    \"\"\"\n",
        "    super(ResnetPeriodEstimator, self).__init__()\n",
        "\n",
        "    # Model params.\n",
        "    self.num_frames = num_frames\n",
        "    self.image_size = image_size\n",
        "\n",
        "    self.base_model_layer_name = base_model_layer_name\n",
        "\n",
        "    self.temperature = temperature\n",
        "\n",
        "    self.dropout_rate = dropout_rate\n",
        "    self.l2_reg_weight = l2_reg_weight\n",
        "\n",
        "    self.temporal_conv_channels = temporal_conv_channels\n",
        "    self.temporal_conv_kernel_size = temporal_conv_kernel_size\n",
        "    self.temporal_conv_dilation_rate = temporal_conv_dilation_rate\n",
        "\n",
        "    self.conv_channels = conv_channels\n",
        "    self.conv_kernel_size = conv_kernel_size\n",
        "    # Transformer config in form of (channels, heads, bottleneck channels).\n",
        "    self.transformer_layers_config = transformer_layers_config\n",
        "    self.transformer_dropout_rate = transformer_dropout_rate\n",
        "    self.transformer_reorder_ln = transformer_reorder_ln\n",
        "\n",
        "    self.period_fc_channels = period_fc_channels\n",
        "    self.within_period_fc_channels = within_period_fc_channels\n",
        "\n",
        "    # Base ResNet50 Model.\n",
        "    # Built without pretrained weights (weights=None) and truncated at\n",
        "    # base_model_layer_name, so the backbone outputs a spatial feature map.\n",
        "    base_model = tf.keras.applications.ResNet50V2(\n",
        "        include_top=False, weights=None, pooling='max')\n",
        "    self.base_model = tf.keras.models.Model(\n",
        "        inputs=base_model.input,\n",
        "        outputs=base_model.get_layer(self.base_model_layer_name).output)\n",
        "\n",
        "    # 3D Conv on k Frames\n",
        "    # Dilation is (rate, 1, 1): only the first (time) axis is dilated.\n",
        "    self.temporal_conv_layers = [\n",
        "        layers.Conv3D(self.temporal_conv_channels,\n",
        "                      self.temporal_conv_kernel_size,\n",
        "                      padding='same',\n",
        "                      dilation_rate=(self.temporal_conv_dilation_rate, 1, 1),\n",
        "                      kernel_regularizer=regularizers.l2(self.l2_reg_weight),\n",
        "                      kernel_initializer='he_normal')]\n",
        "    self.temporal_bn_layers = [layers.BatchNormalization()\n",
        "                               for _ in self.temporal_conv_layers]\n",
        "\n",
        "    # Counting Module (Self-sim \u003e Conv \u003e Transformer \u003e Classifier)\n",
        "    self.conv_3x3_layer = layers.Conv2D(self.conv_channels,\n",
        "                                        self.conv_kernel_size,\n",
        "                                        padding='same',\n",
        "                                        activation=tf.nn.relu)\n",
        "\n",
        "    # Both heads project to the channel count of the first transformer layer.\n",
        "    channels = self.transformer_layers_config[0][0]\n",
        "    self.input_projection = layers.Dense(\n",
        "        channels, kernel_regularizer=regularizers.l2(self.l2_reg_weight),\n",
        "        activation=None)\n",
        "    self.input_projection2 = layers.Dense(\n",
        "        channels, kernel_regularizer=regularizers.l2(self.l2_reg_weight),\n",
        "        activation=None)\n",
        "\n",
        "    # Learned positional encodings, one scalar per frame position; the shape\n",
        "    # [1, length, 1] broadcasts over batch and channels when added in call().\n",
        "    length = self.num_frames\n",
        "    self.pos_encoding = tf.compat.v1.get_variable(\n",
        "        name='resnet_period_estimator/pos_encoding',\n",
        "        shape=[1, length, 1],\n",
        "        initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))\n",
        "    self.pos_encoding2 = tf.compat.v1.get_variable(\n",
        "        name='resnet_period_estimator/pos_encoding2',\n",
        "        shape=[1, length, 1],\n",
        "        initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))\n",
        "\n",
        "    # Transformer stack of the period head.\n",
        "    self.transformer_layers = []\n",
        "    for d_model, num_heads, dff in self.transformer_layers_config:\n",
        "      self.transformer_layers.append(\n",
        "          TransformerLayer(d_model, num_heads, dff,\n",
        "                           self.transformer_dropout_rate,\n",
        "                           self.transformer_reorder_ln))\n",
        "\n",
        "    # Separate transformer stack (same config) of the within-period head.\n",
        "    self.transformer_layers2 = []\n",
        "    for d_model, num_heads, dff in self.transformer_layers_config:\n",
        "      self.transformer_layers2.append(\n",
        "          TransformerLayer(d_model, num_heads, dff,\n",
        "                           self.transformer_dropout_rate,\n",
        "                           self.transformer_reorder_ln))\n",
        "\n",
        "    # Period Prediction Module.\n",
        "    # A single Dropout layer instance is shared by both prediction heads.\n",
        "    self.dropout_layer = layers.Dropout(self.dropout_rate)\n",
        "    num_preds = self.num_frames//2\n",
        "    self.fc_layers = []\n",
        "    for channels in self.period_fc_channels:\n",
        "      self.fc_layers.append(layers.Dense(\n",
        "          channels, kernel_regularizer=regularizers.l2(self.l2_reg_weight),\n",
        "          activation=tf.nn.relu))\n",
        "    # Final layer has no activation: it emits raw scores.\n",
        "    self.fc_layers.append(layers.Dense(\n",
        "        num_preds, kernel_regularizer=regularizers.l2(self.l2_reg_weight)))\n",
        "\n",
        "    # Within Period Module\n",
        "    num_preds = 1\n",
        "    self.within_period_fc_layers = []\n",
        "    for channels in self.within_period_fc_channels:\n",
        "      self.within_period_fc_layers.append(layers.Dense(\n",
        "          channels, kernel_regularizer=regularizers.l2(self.l2_reg_weight),\n",
        "          activation=tf.nn.relu))\n",
        "    self.within_period_fc_layers.append(layers.Dense(\n",
        "        num_preds, kernel_regularizer=regularizers.l2(self.l2_reg_weight)))\n",
        "\n",
        "  def call(self, x):\n",
        "    \"\"\"Runs the full model on a batch of clips.\n",
        "\n",
        "    Args:\n",
        "      x: Batch of clips. It is reshaped to [-1, image_size, image_size, 3]\n",
        "        frames, so it is presumably (batch, num_frames, image_size,\n",
        "        image_size, 3) - this matches the dummy input in get_repnet_model.\n",
        "\n",
        "    Returns:\n",
        "      Tuple (x, within_period_x, final_embs):\n",
        "        x: Period scores, (batch, num_frames, num_frames // 2); raw outputs of\n",
        "          the last Dense layer (no activation).\n",
        "        within_period_x: Within-period scores, (batch, num_frames, 1); also\n",
        "          raw Dense outputs.\n",
        "        final_embs: Per-frame embeddings after the temporal conv and the\n",
        "          spatial max-pool, (batch, frames, temporal_conv_channels).\n",
        "    \"\"\"\n",
        "    # Ensures we are always using the right batch_size during train/eval.\n",
        "    batch_size = tf.shape(x)[0]\n",
        "    # Conv Feature Extractor.\n",
        "    # Frames of all clips are folded into the batch axis for the 2D backbone.\n",
        "    x = tf.reshape(x, [-1, self.image_size, self.image_size, 3])\n",
        "    x = self.base_model(x)\n",
        "    h = tf.shape(x)[1]\n",
        "    w = tf.shape(x)[2]\n",
        "    c = tf.shape(x)[3]\n",
        "    # Restore the clip structure: [batch, frames, h, w, c].\n",
        "    x = tf.reshape(x, [batch_size, -1, h, w, c])\n",
        "\n",
        "    # 3D Conv to give temporal context to per-frame embeddings.\n",
        "    for bn_layer, conv_layer in zip(self.temporal_bn_layers,\n",
        "                                    self.temporal_conv_layers):\n",
        "      x = conv_layer(x)\n",
        "      x = bn_layer(x)\n",
        "      x = tf.nn.relu(x)\n",
        "\n",
        "    # Global max-pool over the two spatial axes: one vector per frame.\n",
        "    x = tf.reduce_max(x, [2, 3])\n",
        "\n",
        "    # Reshape and prepare embs for output.\n",
        "    final_embs = x\n",
        "\n",
        "    # Get self-similarity matrix.\n",
        "    x = get_sims(x, self.temperature)\n",
        "\n",
        "    # 3x3 conv layer on self-similarity matrix.\n",
        "    x = self.conv_3x3_layer(x)\n",
        "    # Each frame's row of the matrix (all columns x channels) becomes its\n",
        "    # feature vector; both heads start from this same tensor.\n",
        "    x = tf.reshape(x, [batch_size, self.num_frames, -1])\n",
        "    within_period_x = x\n",
        "\n",
        "    # Period prediction.\n",
        "    x = self.input_projection(x)\n",
        "    x += self.pos_encoding\n",
        "    for transformer_layer in self.transformer_layers:\n",
        "      x = transformer_layer(x)\n",
        "    x = flatten_sequential_feats(x, batch_size, self.num_frames)\n",
        "    # Dropout precedes every fully connected layer, including the last one.\n",
        "    for fc_layer in self.fc_layers:\n",
        "      x = self.dropout_layer(x)\n",
        "      x = fc_layer(x)\n",
        "\n",
        "    # Within period prediction.\n",
        "    within_period_x = self.input_projection2(within_period_x)\n",
        "    within_period_x += self.pos_encoding2\n",
        "    for transformer_layer in self.transformer_layers2:\n",
        "      within_period_x = transformer_layer(within_period_x)\n",
        "    within_period_x = flatten_sequential_feats(within_period_x,\n",
        "                                               batch_size,\n",
        "                                               self.num_frames)\n",
        "    for fc_layer in self.within_period_fc_layers:\n",
        "      within_period_x = self.dropout_layer(within_period_x)\n",
        "      within_period_x = fc_layer(within_period_x)\n",
        "\n",
        "    return x, within_period_x, final_embs\n",
        "\n",
        "  @tf.function\n",
        "  def preprocess(self, imgs):\n",
        "    \"\"\"Scales pixel values by (x - 127.5) / 127.5 and resizes to image_size.\n",
        "\n",
        "    Args:\n",
        "      imgs: Image tensor; presumably 8-bit frames in [0, 255], which the\n",
        "        scaling maps to [-1, 1] - TODO confirm against callers.\n",
        "\n",
        "    Returns:\n",
        "      float32 tensor resized to (image_size, image_size).\n",
        "    \"\"\"\n",
        "    imgs = tf.cast(imgs, tf.float32)\n",
        "    imgs -= 127.5\n",
        "    imgs /= 127.5\n",
        "    imgs = tf.image.resize(imgs, (self.image_size, self.image_size))\n",
        "    return imgs\n",
        "\n",
        "\n",
        "def get_sims(embs, temperature):\n",
        "  \"\"\"Builds softmax-normalized self-similarity matrices for a batch.\n",
        "\n",
        "  Args:\n",
        "    embs: Batch of embedding sequences; everything after the first two axes\n",
        "      is flattened into a single feature axis.\n",
        "    temperature: Divisor applied to the similarities before the softmax.\n",
        "\n",
        "  Returns:\n",
        "    Tensor of shape (batch, seq_len, seq_len, 1); the softmax is taken over\n",
        "    the last similarity axis.\n",
        "  \"\"\"\n",
        "  dims = tf.shape(embs)\n",
        "  flat_embs = tf.reshape(embs, [dims[0], dims[1], -1])\n",
        "\n",
        "  # Similarity of two frames is the negated squared L2 distance between them.\n",
        "  neg_dists = tf.map_fn(\n",
        "      lambda seq: -1.0 * pairwise_l2_distance(seq, seq), flat_embs)\n",
        "  row_probs = tf.nn.softmax(neg_dists / temperature, axis=-1)\n",
        "  return tf.expand_dims(row_probs, -1)\n",
        "\n",
        "\n",
        "def flatten_sequential_feats(x, batch_size, seq_len):\n",
        "  \"\"\"Collapses all trailing axes of x, giving (batch_size, seq_len, -1).\"\"\"\n",
        "  target_shape = [batch_size, seq_len, -1]\n",
        "  return tf.reshape(x, target_shape)\n",
        "\n",
        "\n",
        "# Transformer from https://www.tensorflow.org/tutorials/text/transformer .\n",
        "def scaled_dot_product_attention(q, k, v, mask):\n",
        "  \"\"\"Computes softmax(q k^T / sqrt(depth)) v together with the weights.\n",
        "\n",
        "  q, k and v need matching leading dimensions, and k and v need the same\n",
        "  penultimate dimension (seq_len_k == seq_len_v). The mask may have any shape\n",
        "  that broadcasts against the attention logits.\n",
        "\n",
        "  Args:\n",
        "    q: query shape == (..., seq_len_q, depth)\n",
        "    k: key shape == (..., seq_len_k, depth)\n",
        "    v: value shape == (..., seq_len_v, depth_v)\n",
        "    mask: Float tensor broadcastable to (..., seq_len_q, seq_len_k), or None\n",
        "      to attend everywhere. Non-zero entries suppress the matching position.\n",
        "\n",
        "  Returns:\n",
        "    outputs: shape == (..., seq_len_q, depth_v)\n",
        "    attention_weights: shape == (..., seq_len_q, seq_len_k)\n",
        "  \"\"\"\n",
        "  key_depth = tf.cast(tf.shape(k)[-1], tf.float32)\n",
        "\n",
        "  # Raw dot products scaled by sqrt(depth): (..., seq_len_q, seq_len_k).\n",
        "  logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(key_depth)\n",
        "\n",
        "  # Masked positions receive a large negative logit, so the softmax below\n",
        "  # sends their weight to (almost) zero.\n",
        "  if mask is not None:\n",
        "    logits += (mask * -1e9)\n",
        "\n",
        "  # Normalizing over the key axis makes each query's weights sum to 1.\n",
        "  attention_weights = tf.nn.softmax(logits, axis=-1)\n",
        "\n",
        "  # Weighted sum of the values: (..., seq_len_q, depth_v).\n",
        "  outputs = tf.matmul(attention_weights, v)\n",
        "  return outputs, attention_weights\n",
        "\n",
        "\n",
        "def point_wise_feed_forward_network(d_model, dff):\n",
        "  \"\"\"Returns a two-layer MLP: Dense(dff, relu) followed by Dense(d_model).\"\"\"\n",
        "  hidden = tf.keras.layers.Dense(dff, activation='relu')\n",
        "  projection = tf.keras.layers.Dense(d_model)\n",
        "  return tf.keras.Sequential([hidden, projection])\n",
        "\n",
        "\n",
        "class MultiHeadAttention(tf.keras.layers.Layer):\n",
        "  \"\"\"Multi-head attention with learned q/k/v and output projections.\"\"\"\n",
        "\n",
        "  def __init__(self, d_model, num_heads):\n",
        "    \"\"\"Creates the projection layers.\n",
        "\n",
        "    Args:\n",
        "      d_model: Total channel count; must be divisible by num_heads.\n",
        "      num_heads: Number of attention heads.\n",
        "    \"\"\"\n",
        "    super().__init__()\n",
        "    self.num_heads = num_heads\n",
        "    self.d_model = d_model\n",
        "\n",
        "    # Every head gets an equal slice of the model channels.\n",
        "    assert d_model % self.num_heads == 0\n",
        "    self.depth = d_model // self.num_heads\n",
        "\n",
        "    # Input projections for query, key and value.\n",
        "    self.wq = tf.keras.layers.Dense(d_model)\n",
        "    self.wk = tf.keras.layers.Dense(d_model)\n",
        "    self.wv = tf.keras.layers.Dense(d_model)\n",
        "\n",
        "    # Output projection applied after the heads are merged again.\n",
        "    self.dense = tf.keras.layers.Dense(d_model)\n",
        "\n",
        "  def split_heads(self, x, batch_size):\n",
        "    \"\"\"Reshapes (batch, seq, d_model) into (batch, num_heads, seq, depth).\"\"\"\n",
        "    per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))\n",
        "    return tf.transpose(per_head, perm=[0, 2, 1, 3])\n",
        "\n",
        "  def call(self, v, k, q, mask):\n",
        "    \"\"\"Attends from q over k/v; returns (outputs, attention_weights).\"\"\"\n",
        "    batch_size = tf.shape(q)[0]\n",
        "\n",
        "    # Project, then split each projection into heads:\n",
        "    # (batch_size, seq_len, d_model) -\u003e (batch_size, num_heads, seq_len, depth)\n",
        "    q = self.split_heads(self.wq(q), batch_size)\n",
        "    k = self.split_heads(self.wk(k), batch_size)\n",
        "    v = self.split_heads(self.wv(v), batch_size)\n",
        "\n",
        "    # head_outputs: (batch_size, num_heads, seq_len_q, depth)\n",
        "    # attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)\n",
        "    head_outputs, attention_weights = scaled_dot_product_attention(\n",
        "        q, k, v, mask)\n",
        "\n",
        "    # Move the head axis next to depth and merge the two back into d_model:\n",
        "    # (batch_size, seq_len_q, num_heads, depth) -\u003e (batch_size, seq_len_q, d_model)\n",
        "    head_outputs = tf.transpose(head_outputs, perm=[0, 2, 1, 3])\n",
        "    merged = tf.reshape(head_outputs, (batch_size, -1, self.d_model))\n",
        "\n",
        "    # Final projection keeps the shape (batch_size, seq_len_q, d_model).\n",
        "    return self.dense(merged), attention_weights\n",
        "\n",
        "\n",
        "class TransformerLayer(tf.keras.layers.Layer):\n",
        "  \"\"\"One transformer encoder layer (https://arxiv.org/abs/1706.03762).\n",
        "\n",
        "  With reorder_ln=True layer normalization is applied to the input of each\n",
        "  sub-block; otherwise it is applied to the result of each residual sum.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, d_model, num_heads, dff,\n",
        "               dropout_rate=0.1,\n",
        "               reorder_ln=False):\n",
        "    \"\"\"Creates the attention, feed-forward, norm and dropout sub-layers.\n",
        "\n",
        "    Args:\n",
        "      d_model: Channel count of the layer input and output.\n",
        "      num_heads: Number of attention heads.\n",
        "      dff: Hidden size of the feed-forward network.\n",
        "      dropout_rate: Dropout rate applied to both sub-block outputs.\n",
        "      reorder_ln: If True, normalize before each sub-block instead of after\n",
        "        each residual sum.\n",
        "    \"\"\"\n",
        "    super().__init__()\n",
        "\n",
        "    self.mha = MultiHeadAttention(d_model, num_heads)\n",
        "    self.ffn = point_wise_feed_forward_network(d_model, dff)\n",
        "\n",
        "    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)\n",
        "    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)\n",
        "\n",
        "    self.dropout1 = tf.keras.layers.Dropout(dropout_rate)\n",
        "    self.dropout2 = tf.keras.layers.Dropout(dropout_rate)\n",
        "\n",
        "    self.reorder_ln = reorder_ln\n",
        "\n",
        "  def call(self, x):\n",
        "    \"\"\"Applies self-attention and the feed-forward net with residuals.\n",
        "\n",
        "    Input and output both have shape (batch_size, input_seq_len, d_model).\n",
        "    \"\"\"\n",
        "    if self.reorder_ln:\n",
        "      # Normalize the input of each sub-block; residuals use the raw tensors.\n",
        "      normed = self.layernorm1(x)\n",
        "      attn_output, _ = self.mha(normed, normed, normed, mask=None)\n",
        "      out1 = x + self.dropout1(attn_output)\n",
        "\n",
        "      ffn_output = self.ffn(self.layernorm2(out1))\n",
        "      return out1 + self.dropout2(ffn_output)\n",
        "\n",
        "    # Normalize after each residual sum.\n",
        "    attn_output, _ = self.mha(x, x, x, mask=None)\n",
        "    out1 = self.layernorm1(x + self.dropout1(attn_output))\n",
        "\n",
        "    ffn_output = self.ffn(out1)\n",
        "    return self.layernorm2(out1 + self.dropout2(ffn_output))\n",
        "\n",
        "\n",
        "def pairwise_l2_distance(a, b):\n",
        "  \"\"\"Computes squared L2 distances between all rows of a and all rows of b.\n",
        "\n",
        "  Uses the expansion |a|^2 - 2 a.b + |b|^2 and clamps the result at zero.\n",
        "  \"\"\"\n",
        "  sq_norm_a = tf.reshape(tf.reduce_sum(tf.square(a), 1), [-1, 1])\n",
        "  sq_norm_b = tf.reshape(tf.reduce_sum(tf.square(b), 1), [1, -1])\n",
        "  cross_term = tf.matmul(a, b, False, True)\n",
        "  return tf.maximum(sq_norm_a - 2.0 * cross_term + sq_norm_b, 0.0)\n",
        "\n",
        "\n",
        "def get_repnet_model(logdir):\n",
        "  \"\"\"Returns a trained RepNet model.\n",
        "\n",
        "  Builds a ResnetPeriodEstimator with default hyper-parameters, wraps its\n",
        "  `call` in a tf.function, and restores the latest checkpoint found in logdir.\n",
        "\n",
        "  Args:\n",
        "    logdir (string): Path to directory where checkpoint will be downloaded.\n",
        "\n",
        "  Returns:\n",
        "    model (Keras model): Trained RepNet model.\n",
        "\n",
        "  Raises:\n",
        "    AssertionError: If TensorFlow is not executing eagerly.\n",
        "    ValueError: If logdir contains no checkpoint.\n",
        "  \"\"\"\n",
        "  # Check if we are in eager mode.\n",
        "  assert tf.executing_eagerly()\n",
        "\n",
        "  # Models will be called in eval mode.\n",
        "  tf.keras.backend.set_learning_phase(0)\n",
        "\n",
        "  # Define RepNet model.\n",
        "  model = ResnetPeriodEstimator()\n",
        "  # tf.function for speed.\n",
        "  model.call = tf.function(model.call)\n",
        "\n",
        "  # Define checkpoint and checkpoint manager.\n",
        "  ckpt = tf.train.Checkpoint(model=model)\n",
        "  ckpt_manager = tf.train.CheckpointManager(\n",
        "      ckpt, directory=logdir, max_to_keep=10)\n",
        "  latest_ckpt = ckpt_manager.latest_checkpoint\n",
        "  # Printed before the check below, so a missing checkpoint shows up as None.\n",
        "  print('Loading from: ', latest_ckpt)\n",
        "  if not latest_ckpt:\n",
        "    raise ValueError('Path does not have a checkpoint to load.')\n",
        "  # Restore weights.\n",
        "  # NOTE(review): restore is issued before the first forward pass; this\n",
        "  # presumably relies on TensorFlow restoring variables as they get created -\n",
        "  # keep this ordering unless verified otherwise.\n",
        "  ckpt.restore(latest_ckpt).expect_partial()\n",
        "\n",
        "  # Pass dummy frames to build graph.\n",
        "  # The shape matches the model defaults (64 frames of 112x112 RGB).\n",
        "  model(tf.random.uniform((1, 64, 112, 112, 3)))\n",
        "  return model\n",
        "\n",
        "\n",
        "def unnorm(query_frame):\n",
        "  \"\"\"Min-max rescales an array so that its values span [0, 1].\"\"\"\n",
        "  lowest = query_frame.min()\n",
        "  value_range = query_frame.max() - lowest\n",
        "  # Floor the divisor so a constant frame does not divide by zero.\n",
        "  return (query_frame - lowest) / max(1e-7, value_range)\n",
        "\n",
        "\n",
        "def create_count_video(frames,\n",
        "                       per_frame_counts,\n",
        "                       within_period,\n",
        "                       score,\n",
        "                       fps,\n",
        "                       output_file,\n",
        "                       delay,\n",
        "                       plot_count=True,\n",
        "                       plot_within_period=False,\n",
        "                       plot_score=False):\n",
        "  \"\"\"Renders the frames as an animation annotated with RepNet predictions.\n",
        "\n",
        "  Nothing is written when both plot_count and plot_within_period are False.\n",
        "\n",
        "  Args:\n",
        "    frames (List): Images as NumPy arrays, one per video frame.\n",
        "    per_frame_counts (List): Per-frame repetition rate as floats; their\n",
        "      cumulative sum is displayed as the running count.\n",
        "    within_period (List): Per-frame scores between 0 and 1 telling whether the\n",
        "      frame lies inside the periodic/repeating portion of the video.\n",
        "    score (float): Confidence between 0 and 1 of the model's count prediction.\n",
        "    fps (int): Frames per second of the input video, used to express the\n",
        "      per-frame rate in Hz.\n",
        "    output_file (string): Path of the output video; must end in .mp4 or .gif.\n",
        "    delay (integer): Delay between frames, passed to FuncAnimation as interval.\n",
        "    plot_count (boolean): if True shows the running count and rate.\n",
        "    plot_within_period (boolean): if True plots the per-frame within period\n",
        "      scores next to the video.\n",
        "    plot_score (boolean): if True shows the confidence of the model along with\n",
        "      the count or within_period scores.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If output_file does not end in .mp4 or .gif.\n",
        "  \"\"\"\n",
        "  extension = output_file[-4:]\n",
        "  if extension not in ['.mp4', '.gif']:\n",
        "    raise ValueError('Output format can only be mp4 or gif')\n",
        "  num_frames = len(frames)\n",
        "\n",
        "  running_counts = np.cumsum(per_frame_counts)\n",
        "  final_count = running_counts[-1]\n",
        "\n",
        "  def count(idx):\n",
        "    \"\"\"Running count at frame idx, rounded to the nearest integer.\"\"\"\n",
        "    return int(np.round(running_counts[idx]))\n",
        "\n",
        "  def rate(idx):\n",
        "    \"\"\"Repetition rate at frame idx in Hz.\"\"\"\n",
        "    return per_frame_counts[idx] * fps\n",
        "\n",
        "  def add_score_suptitle():\n",
        "    \"\"\"Adds the final count and confidence as figure-level title.\"\"\"\n",
        "    plt.suptitle('Pred Count: %d, '\n",
        "                 'Prob: %0.1f' % (int(np.around(final_count)), score),\n",
        "                 fontsize=24)\n",
        "\n",
        "  def save_animation(anim):\n",
        "    \"\"\"Writes the animation; gifs go through the imagemagick writer.\"\"\"\n",
        "    if extension == '.mp4':\n",
        "      anim.save(output_file, dpi=100, fps=24)\n",
        "    else:\n",
        "      anim.save(output_file, writer='imagemagick', fps=24, dpi=100)\n",
        "\n",
        "  # Frame 0 is drawn up front, so the animation starts at frame 1.\n",
        "  animated_indices = np.arange(1, num_frames)\n",
        "\n",
        "  if plot_count and not plot_within_period:\n",
        "    # Single panel: the video with count and rate in the title.\n",
        "    fig = plt.figure(figsize=(10, 12), tight_layout=True)\n",
        "    im = plt.imshow(unnorm(frames[0]))\n",
        "    if plot_score:\n",
        "      add_score_suptitle()\n",
        "\n",
        "    plt.title('Count 0, Rate: 0', fontsize=24)\n",
        "    plt.axis('off')\n",
        "    plt.grid(visible=False)\n",
        "\n",
        "    def update_count_plot(i):\n",
        "      \"\"\"Shows frame i and refreshes the count/rate title.\"\"\"\n",
        "      im.set_data(unnorm(frames[i]))\n",
        "      plt.title('Count %d, Rate: %0.4f Hz' % (count(i), rate(i)), fontsize=24)\n",
        "\n",
        "    save_animation(FuncAnimation(\n",
        "        fig,\n",
        "        update_count_plot,\n",
        "        frames=animated_indices,\n",
        "        interval=delay,\n",
        "        blit=False))\n",
        "\n",
        "  elif plot_within_period:\n",
        "    # Two panels: the video on the left, within-period curve on the right.\n",
        "    fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n",
        "    video_ax, curve_ax = axs[0], axs[1]\n",
        "    im = video_ax.imshow(unnorm(frames[0]))\n",
        "    curve_ax.plot(0, within_period[0])\n",
        "    curve_ax.set_xlim((0, num_frames))\n",
        "    curve_ax.set_ylim((0, 1))\n",
        "\n",
        "    if plot_score:\n",
        "      add_score_suptitle()\n",
        "\n",
        "    if plot_count:\n",
        "      video_ax.set_title('Count 0, Rate: 0', fontsize=20)\n",
        "\n",
        "    plt.axis('off')\n",
        "    plt.grid(visible=False)\n",
        "\n",
        "    def update_within_period_plot(i):\n",
        "      \"\"\"Shows frame i and redraws the within-period curve up to frame i.\"\"\"\n",
        "      im.set_data(unnorm(frames[i]))\n",
        "      video_ax.set_xticks([])\n",
        "      video_ax.set_yticks([])\n",
        "      if plot_count:\n",
        "        video_ax.set_title('Count %d, Rate: %0.4f Hz' % (count(i), rate(i)),\n",
        "                           fontsize=20)\n",
        "      # within_period may be shorter or longer than the video, so frame\n",
        "      # indices are rescaled into its index range.\n",
        "      xs = list(range(i))\n",
        "      ys = [within_period[int(idx * len(within_period) / num_frames)]\n",
        "            for idx in xs]\n",
        "      curve_ax.clear()\n",
        "      curve_ax.set_title('Within Period or Not', fontsize=20)\n",
        "      curve_ax.set_xlim((0, num_frames))\n",
        "      curve_ax.set_ylim((-0.05, 1.05))\n",
        "      curve_ax.plot(xs, ys)\n",
        "\n",
        "    save_animation(FuncAnimation(\n",
        "        fig,\n",
        "        update_within_period_plot,\n",
        "        frames=animated_indices,\n",
        "        interval=delay,\n",
        "        blit=False,\n",
        "    ))\n",
        "\n",
        "  plt.close()\n",
        "\n",
        "\n",
        "def show_video(video_path):\n",
        "  \"\"\"Returns an HTML video player with the file embedded as a data URL.\n",
        "\n",
        "  Args:\n",
        "    video_path (string): Path of the video file; it is embedded with the\n",
        "      video/mp4 MIME type.\n",
        "\n",
        "  Returns:\n",
        "    IPython HTML object containing a video tag with the base64-encoded file.\n",
        "  \"\"\"\n",
        "  # The context manager closes the file handle even if reading fails.\n",
        "  with open(video_path, 'rb') as video_file:\n",
        "    mp4 = video_file.read()\n",
        "  data_url = 'data:video/mp4;base64,' + base64.b64encode(mp4).decode()\n",
        "  return HTML(\"\"\"\u003cvideo width=600 controls\u003e\n",
        "      \u003csource src=\"%s\" type=\"video/mp4\"\u003e\u003c/video\u003e\n",
        "  \"\"\" % data_url)\n",
        "\n",
        "\n",
        "def viz_reps(frames,\n",
        "             count,\n",
        "             score,\n",
        "             alpha=1.0,\n",
        "             pichart=True,\n",
        "             colormap=plt.cm.PuBu,\n",
        "             num_frames=None,\n",
        "             interval=30,\n",
        "             plot_score=True):\n",
        "  \"\"\"Visualize repetitions.\n",
        "\n",
        "  Overlays a running repetition counter on the frames, saves the animation to\n",
        "  /tmp/output.mp4 and returns an embedded player for it.\n",
        "\n",
        "  Args:\n",
        "    frames: Sequence of images as NumPy arrays of shape (h, w, channels).\n",
        "    count: Either a single number (total count, spread uniformly over all\n",
        "      frames) or a per-frame sequence/array of counts.\n",
        "    score: Confidence shown in the title when plot_score is True.\n",
        "    alpha: Opacity of the pie chart wedges.\n",
        "    pichart: If True draws a pie chart that fills up once per repetition with\n",
        "      the count inside it; otherwise only a large count is drawn.\n",
        "    colormap: Matplotlib colormap used for the wedges and the text.\n",
        "    num_frames: Number of frames to animate; defaults to len(frames).\n",
        "    interval: Delay between frames passed to FuncAnimation.\n",
        "    plot_score: If True shows the score as plot title.\n",
        "\n",
        "  Returns:\n",
        "    IPython HTML object playing the rendered video.\n",
        "  \"\"\"\n",
        "  # The original check was inverted: lists were divided by an int (TypeError)\n",
        "  # and scalars were passed through unexpanded. A 0-d count is a total that\n",
        "  # gets spread uniformly; anything else is already a per-frame sequence.\n",
        "  if np.ndim(count) == 0:\n",
        "    counts = len(frames) * [count/len(frames)]\n",
        "  else:\n",
        "    counts = count\n",
        "  sum_counts = np.cumsum(counts)\n",
        "  tmp_path = '/tmp/output.mp4'\n",
        "  fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(5, 5),\n",
        "                         tight_layout=True,)\n",
        "\n",
        "  # Overlay geometry is specified for a 112x112 frame and scaled to the\n",
        "  # actual frame size.\n",
        "  h, w, _ = np.shape(frames[0])\n",
        "  wedge_x = 95 / 112 * w\n",
        "  wedge_y = 17 / 112 * h\n",
        "  wedge_r = 15 / 112 * h\n",
        "  txt_x = 95 / 112 * w\n",
        "  txt_y = 19 / 112 * h\n",
        "  otxt_size = 62 / 112 * h\n",
        "\n",
        "  if plot_score:\n",
        "    plt.title('Score:%.2f' % score, fontsize=20)\n",
        "  im0 = ax.imshow(unnorm(frames[0]))\n",
        "\n",
        "  if not num_frames:\n",
        "    num_frames = len(frames)\n",
        "\n",
        "  if pichart:\n",
        "    # Two complementary wedges together always form a full disc.\n",
        "    wedge1 = matplotlib.patches.Wedge(\n",
        "        center=(wedge_x, wedge_y),\n",
        "        r=wedge_r,\n",
        "        theta1=0,\n",
        "        theta2=0,\n",
        "        color=colormap(1.),\n",
        "        alpha=alpha)\n",
        "    wedge2 = matplotlib.patches.Wedge(\n",
        "        center=(wedge_x, wedge_y),\n",
        "        r=wedge_r,\n",
        "        theta1=0,\n",
        "        theta2=0,\n",
        "        color=colormap(0.5),\n",
        "        alpha=alpha)\n",
        "\n",
        "    ax.add_patch(wedge1)\n",
        "    ax.add_patch(wedge2)\n",
        "    txt = ax.text(\n",
        "        txt_x,\n",
        "        txt_y,\n",
        "        '0',\n",
        "        size=35,\n",
        "        ha='center',\n",
        "        va='center',\n",
        "        alpha=0.9,\n",
        "        color='white',\n",
        "    )\n",
        "\n",
        "  else:\n",
        "    txt = ax.text(\n",
        "        txt_x,\n",
        "        txt_y,\n",
        "        '0',\n",
        "        size=otxt_size,\n",
        "        ha='center',\n",
        "        va='center',\n",
        "        alpha=0.8,\n",
        "        color=colormap(0.4),\n",
        "    )\n",
        "\n",
        "  def update(i):\n",
        "    \"\"\"Update plot with next frame.\"\"\"\n",
        "    im0.set_data(unnorm(frames[i]))\n",
        "    ctr = int(sum_counts[i])\n",
        "    if pichart:\n",
        "      # Swap the wedge colors on every completed repetition so the disc\n",
        "      # appears to refill with the other color.\n",
        "      if ctr%2 == 0:\n",
        "        wedge1.set_color(colormap(1.0))\n",
        "        wedge2.set_color(colormap(0.5))\n",
        "      else:\n",
        "        wedge1.set_color(colormap(0.5))\n",
        "        wedge2.set_color(colormap(1.0))\n",
        "\n",
        "      # The fractional part of the running count sets the split angle.\n",
        "      wedge1.set_theta1(-90)\n",
        "      wedge1.set_theta2(-90 - 360 * (1 - sum_counts[i] % 1.0))\n",
        "      wedge2.set_theta1(-90 - 360 * (1 - sum_counts[i] % 1.0))\n",
        "      wedge2.set_theta2(-90)\n",
        "\n",
        "    txt.set_text(ctr)\n",
        "    ax.grid(False)\n",
        "    ax.set_xticks([])\n",
        "    ax.set_yticks([])\n",
        "    plt.tight_layout()\n",
        "\n",
        "  anim = FuncAnimation(\n",
        "      fig,\n",
        "      update,\n",
        "      frames=num_frames,\n",
        "      interval=interval,\n",
        "      blit=False)\n",
        "  anim.save(tmp_path, dpi=80)\n",
        "  plt.close()\n",
        "  return show_video(tmp_path)\n",
        "\n",
        "\n",
        "def record_video(interval_in_ms, num_frames, quality=0.8):\n",
        "  \"\"\"Capture video from webcam.\"\"\"\n",
        "  # https://colab.research.google.com/notebooks/snippets/advanced_outputs.ipynb.\n",
        "\n",
        "  # Give warning before recording.\n",
        "  for i in range(0, 3):\n",
        "    print('Opening webcam in %d seconds'%(3-i))\n",
        "    time.sleep(1)\n",
        "    output.clear('status_text')\n",
        "\n",
        "  js = Javascript('''\n",
        "    async function recordVideo(interval_in_ms, num_frames, quality) {\n",
        "      const div = document.createElement('div');\n",
        "      const video = document.createElement('video');\n",
        "      video.style.display = 'block';\n",
        "      const stream = await navigator.mediaDevices.getUserMedia({video: true});\n",
        "\n",
        "      // show the video in the HTML element\n",
        "      document.body.appendChild(div);\n",
        "      div.appendChild(video);\n",
        "      video.srcObject = stream;\n",
        "      await video.play();\n",
        "\n",
        "      google.colab.output.setIframeHeight(document.documentElement.scrollHeight,\n",
        "        true);\n",
        "\n",
        "      for (let i = 0; i \u003c num_frames; i++) {\n",
        "        const canvas = document.createElement('canvas');\n",
        "        canvas.width = video.videoWidth;\n",
        "        canvas.height = video.videoHeight;\n",
        "        canvas.getContext('2d').drawImage(video, 0, 0);\n",
        "        img = canvas.toDataURL('image/jpeg', quality);\n",
        "        google.colab.kernel.invokeFunction(\n",
        "        'notebook.get_webcam_video', [img], {});\n",
        "        await new Promise(resolve =\u003e setTimeout(resolve, interval_in_ms));\n",
        "      }\n",
        "      stream.getVideoTracks()[0].stop();\n",
        "      div.remove();\n",
        "    }\n",
        "    ''')\n",
        "  display(js)\n",
        "  eval_js('recordVideo({},{},{})'.format(interval_in_ms, num_frames, quality))\n",
        "\n",
        "\n",
        "def data_uri_to_img(uri):\n",
        "  \"\"\"Convert base64image to Numpy array.\"\"\"\n",
        "  image = base64.b64decode(uri.split(',')[1], validate=True)\n",
        "  # Binary string to PIL image.\n",
        "  image = Image.open(io.BytesIO(image))\n",
        "  image = image.resize((224, 224))\n",
        "  # PIL to Numpy array.\n",
        "  image = np.array(np.array(image, dtype=np.uint8), np.float32)\n",
        "  return image\n",
        "\n",
        "\n",
        "def read_video(video_filename, width=224, height=224):\n",
        "  \"\"\"Read video from file.\"\"\"\n",
        "  cap = cv2.VideoCapture(video_filename)\n",
        "  fps = cap.get(cv2.CAP_PROP_FPS)\n",
        "  frames = []\n",
        "  if cap.isOpened():\n",
        "    while True:\n",
        "      success, frame_bgr = cap.read()\n",
        "      if not success:\n",
        "        break\n",
        "      frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)\n",
        "      frame_rgb = cv2.resize(frame_rgb, (width, height))\n",
        "      frames.append(frame_rgb)\n",
        "  frames = np.asarray(frames)\n",
        "  return frames, fps\n",
        "\n",
        "\n",
        "def get_webcam_video(img_b64):\n",
        "  \"\"\"Populates global variable imgs by converting image URI to Numpy array.\"\"\"\n",
        "  image = data_uri_to_img(img_b64)\n",
        "  imgs.append(image)\n",
        "\n",
        "\n",
        "def download_video_from_url(url_to_video,\n",
        "                            path_to_video='/tmp/video.mp4'):\n",
        "  if os.path.exists(path_to_video):\n",
        "    os.remove(path_to_video)\n",
        "  ydl_opts = {\n",
        "      'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4',\n",
        "      'outtmpl': str(path_to_video),\n",
        "  }\n",
        "  with youtube_dl.YoutubeDL(ydl_opts) as ydl:\n",
        "    ydl.download([url_to_video])\n",
        "\n",
        "\n",
        "def get_score(period_score, within_period_score):\n",
        "  \"\"\"Combine the period and periodicity scores.\"\"\"\n",
        "  within_period_score = tf.nn.sigmoid(within_period_score)[:, 0]\n",
        "  per_frame_periods = tf.argmax(period_score, axis=-1) + 1\n",
        "  pred_period_conf = tf.reduce_max(\n",
        "      tf.nn.softmax(period_score, axis=-1), axis=-1)\n",
        "  pred_period_conf = tf.where(\n",
        "      tf.math.less(per_frame_periods, 3), 0.0, pred_period_conf)\n",
        "  within_period_score *= pred_period_conf\n",
        "  within_period_score = np.sqrt(within_period_score)\n",
        "  pred_score = tf.reduce_mean(within_period_score)\n",
        "  return pred_score, within_period_score\n",
        "\n",
        "\n",
        "def get_counts(model, frames, strides, batch_size,\n",
        "               threshold,\n",
        "               within_period_threshold,\n",
        "               constant_speed=False,\n",
        "               median_filter=False,\n",
        "               fully_periodic=False):\n",
        "  \"\"\"Pass frames through model and conver period predictions to count.\"\"\"\n",
        "  seq_len = len(frames)\n",
        "  raw_scores_list = []\n",
        "  scores = []\n",
        "  within_period_scores_list = []\n",
        "\n",
        "  if fully_periodic:\n",
        "    within_period_threshold = 0.0\n",
        "\n",
        "  frames = model.preprocess(frames)\n",
        "\n",
        "  for stride in strides:\n",
        "    num_batches = int(np.ceil(seq_len/model.num_frames/stride/batch_size))\n",
        "    raw_scores_per_stride = []\n",
        "    within_period_score_stride = []\n",
        "    for batch_idx in range(num_batches):\n",
        "      idxes = tf.range(batch_idx*batch_size*model.num_frames*stride,\n",
        "                       (batch_idx+1)*batch_size*model.num_frames*stride,\n",
        "                       stride)\n",
        "      idxes = tf.clip_by_value(idxes, 0, seq_len-1)\n",
        "      curr_frames = tf.gather(frames, idxes)\n",
        "      curr_frames = tf.reshape(\n",
        "          curr_frames,\n",
        "          [batch_size, model.num_frames, model.image_size, model.image_size, 3])\n",
        "\n",
        "      raw_scores, within_period_scores, _ = model(curr_frames)\n",
        "      raw_scores_per_stride.append(np.reshape(raw_scores.numpy(),\n",
        "                                              [-1, model.num_frames//2]))\n",
        "      within_period_score_stride.append(np.reshape(within_period_scores.numpy(),\n",
        "                                                   [-1, 1]))\n",
        "    raw_scores_per_stride = np.concatenate(raw_scores_per_stride, axis=0)\n",
        "    raw_scores_list.append(raw_scores_per_stride)\n",
        "    within_period_score_stride = np.concatenate(\n",
        "        within_period_score_stride, axis=0)\n",
        "    pred_score, within_period_score_stride = get_score(\n",
        "        raw_scores_per_stride, within_period_score_stride)\n",
        "    scores.append(pred_score)\n",
        "    within_period_scores_list.append(within_period_score_stride)\n",
        "\n",
        "  # Stride chooser\n",
        "  argmax_strides = np.argmax(scores)\n",
        "  chosen_stride = strides[argmax_strides]\n",
        "  raw_scores = np.repeat(\n",
        "      raw_scores_list[argmax_strides], chosen_stride, axis=0)[:seq_len]\n",
        "  within_period = np.repeat(\n",
        "      within_period_scores_list[argmax_strides], chosen_stride,\n",
        "      axis=0)[:seq_len]\n",
        "  within_period_binary = np.asarray(within_period \u003e within_period_threshold)\n",
        "  if median_filter:\n",
        "    within_period_binary = medfilt(np.float32(within_period_binary), 5)\n",
        "    within_period_binary = within_period_binary.astype(bool)\n",
        "\n",
        "  # Select Periodic frames\n",
        "  periodic_idxes = np.where(within_period_binary)[0]\n",
        "\n",
        "  if constant_speed:\n",
        "    # Count by averaging predictions. Smoother but\n",
        "    # assumes constant speed.\n",
        "    scores = tf.reduce_mean(\n",
        "        tf.nn.softmax(raw_scores[periodic_idxes], axis=-1), axis=0)\n",
        "    max_period = np.argmax(scores)\n",
        "    pred_score = scores[max_period]\n",
        "    pred_period = chosen_stride * (max_period + 1)\n",
        "    per_frame_counts = (\n",
        "        np.asarray(seq_len * [1. / pred_period]) *\n",
        "        np.asarray(within_period_binary))\n",
        "  else:\n",
        "    # Count each frame. More noisy but adapts to changes in speed.\n",
        "    pred_score = tf.reduce_mean(within_period)\n",
        "    per_frame_periods = tf.argmax(raw_scores, axis=-1) + 1\n",
        "    per_frame_counts = tf.where(\n",
        "        tf.math.less(per_frame_periods, 3),\n",
        "        0.0,\n",
        "        tf.math.divide(1.0,\n",
        "                       tf.cast(chosen_stride * per_frame_periods, tf.float32)),\n",
        "    )\n",
        "    if median_filter:\n",
        "      per_frame_counts = medfilt(per_frame_counts, 5)\n",
        "\n",
        "    per_frame_counts *= np.asarray(within_period_binary)\n",
        "\n",
        "    pred_period = seq_len/np.sum(per_frame_counts)\n",
        "\n",
        "  if pred_score \u003c threshold:\n",
        "    print('No repetitions detected in video as score '\n",
        "          '%0.2f is less than threshold %0.2f.'%(pred_score, threshold))\n",
        "    per_frame_counts = np.asarray(len(per_frame_counts) * [0.])\n",
        "\n",
        "  return (pred_period, pred_score, within_period,\n",
        "          per_frame_counts, chosen_stride)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dPpgGXG_aalo"
      },
      "source": [
        "## Load trained RepNet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c3-XygLIxwBK"
      },
      "outputs": [],
      "source": [
        "PATH_TO_CKPT = '/tmp/repnet_ckpt/'\n",
        "!mkdir $PATH_TO_CKPT\n",
        "!wget -nc -P $PATH_TO_CKPT https://storage.googleapis.com/repnet_ckpt/checkpoint\n",
        "!wget -nc -P $PATH_TO_CKPT https://storage.googleapis.com/repnet_ckpt/ckpt-88.data-00000-of-00002\n",
        "!wget -nc -P $PATH_TO_CKPT https://storage.googleapis.com/repnet_ckpt/ckpt-88.data-00001-of-00002\n",
        "!wget -nc -P $PATH_TO_CKPT https://storage.googleapis.com/repnet_ckpt/ckpt-88.index\n",
        "\n",
        "model = get_repnet_model(PATH_TO_CKPT)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CNR6sVATy8yX"
      },
      "source": [
        "## Set Params\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PaCB432zUrZu"
      },
      "outputs": [],
      "source": [
        "##@title\n",
        "\n",
        "# FPS while recording video from webcam.\n",
        "WEBCAM_FPS = 16#@param {type:\"integer\"}\n",
        "\n",
        "# Time in seconds to record video on webcam.\n",
        "RECORDING_TIME_IN_SECONDS = 8. #@param {type:\"number\"}\n",
        "\n",
        "# Threshold to consider periodicity in entire video.\n",
        "THRESHOLD = 0.2#@param {type:\"number\"}\n",
        "\n",
        "# Threshold to consider periodicity for individual frames in video.\n",
        "WITHIN_PERIOD_THRESHOLD = 0.5#@param {type:\"number\"}\n",
        "\n",
        "# Use this setting for better results when it is\n",
        "# known action is repeating at constant speed.\n",
        "CONSTANT_SPEED = False#@param {type:\"boolean\"}\n",
        "\n",
        "# Use median filtering in time to ignore noisy frames.\n",
        "MEDIAN_FILTER = True#@param {type:\"boolean\"}\n",
        "\n",
        "# Use this setting for better results when it is\n",
        "# known the entire video is periodic/reapeating and\n",
        "# has no aperiodic frames.\n",
        "FULLY_PERIODIC = False#@param {type:\"boolean\"}\n",
        "\n",
        "# Plot score in visualization video.\n",
        "PLOT_SCORE = False#@param {type:\"boolean\"}\n",
        "\n",
        "# Visualization video's FPS.\n",
        "VIZ_FPS = 30#@param {type:\"integer\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "clHlZdUuTzHF"
      },
      "source": [
        "# Get input video\n",
        "\n",
        "We provide 3 ways to get input video:\n",
        "1. upload video to your Google Drive.\n",
        "2. provide URL of a video.\n",
        "3. record video using webcam."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jgrs0nUjTwNJ"
      },
      "source": [
        "## Get Video from Drive\n",
        "\n",
        "If you have uploaded an input video to your Drive, update `PATH_TO_VIDEO_ON_YOUR_DRIVE` below.\n",
        "\n",
        "This step needs your authorization in order to access the videos from your Google Drive. Running the next code cell will ask for that permission."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qXzo176uTvL0"
      },
      "outputs": [],
      "source": [
        "# Uncomment and run these lines to load videos from Google Drive.\n",
        "\n",
        "# drive.mount('/content/gdrive')\n",
        "\n",
        "# PATH_TO_VIDEO_ON_YOUR_DRIVE = \"gdrive/My Drive/\u003cPATH_TO_VIDEO\u003e.mp4\"\n",
        "# imgs, vid_fps = read_video(PATH_TO_VIDEO_ON_YOUR_DRIVE)\n",
        "# show_video(PATH_TO_VIDEO_ON_YOUR_DRIVE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oFC7Y-egr-8E"
      },
      "source": [
        "## Get Video from URL\n",
        "\n",
        "Provide a link to mp4/gif hosted online."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IScNPW9C2Urp"
      },
      "outputs": [],
      "source": [
        "# Hummingbird flying.\n",
        "VIDEO_URL = 'https://imgur.com/t/hummingbird/m2e2Nfa'\n",
        "\n",
        "# Cheetah running.\n",
        "# VIDEO_URL = 'https://www.reddit.com/r/gifs/comments/4qfif6/cheetah_running_at_63_mph_102_kph/'\n",
        "\n",
        "# Exercise repetition counting.\n",
        "# VIDEO_URL = 'https://www.youtube.com/watch?v=5g1T-ff07kM'\n",
        "\n",
        "# Kitchen activities repetition counting. Tough example with many starts and\n",
        "# stops and varying speeds of action.\n",
        "# VIDEO_URL = 'https://www.youtube.com/watch?v=5EYY2J3nb5c'\n",
        "\n",
        "download_video_from_url(VIDEO_URL)\n",
        "imgs, vid_fps = read_video(\"/tmp/video.mp4\")\n",
        "show_video(\"/tmp/video.mp4\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RinBs-kRadNA"
      },
      "source": [
        "## Get Video from Webcam\n",
        "\n",
        "You will be asked for permission for this page to access your webcam."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OSpVZ5qsaz7Z"
      },
      "source": [
        "### Run following code to capture webcam video"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iZ4xrauHZEfG"
      },
      "outputs": [],
      "source": [
        "# Uncomment and run these lines to record video using webcam.\n",
        "\n",
        "# INTERVAL_IN_MS = 1000//WEBCAM_FPS\n",
        "# vid_fps = WEBCAM_FPS\n",
        "# imgs = []\n",
        "# output.register_callback('notebook.get_webcam_video', get_webcam_video)\n",
        "# record_video(INTERVAL_IN_MS, int(RECORDING_TIME_IN_SECONDS*1000/INTERVAL_IN_MS))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mKWkGlEsa3Tg"
      },
      "source": [
        "# Run RepNet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FUg2vSYhmsT0"
      },
      "outputs": [],
      "source": [
        "print('Running RepNet...')\n",
        "(pred_period, pred_score, within_period,\n",
        " per_frame_counts, chosen_stride) = get_counts(\n",
        "     model,\n",
        "     imgs,\n",
        "     strides=[1,2,3,4],\n",
        "     batch_size=20,\n",
        "     threshold=THRESHOLD,\n",
        "     within_period_threshold=WITHIN_PERIOD_THRESHOLD,\n",
        "     constant_speed=CONSTANT_SPEED,\n",
        "     median_filter=MEDIAN_FILTER,\n",
        "     fully_periodic=FULLY_PERIODIC)\n",
        "print('Visualizing results...')\n",
        "viz_reps(imgs, per_frame_counts, pred_score, interval=1000/VIZ_FPS,\n",
        "         plot_score=PLOT_SCORE)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IYcthjnIJC3P"
      },
      "outputs": [],
      "source": [
        "# Debugging video showing scores, per-frame frequency prediction and\n",
        "# within_period scores.\n",
        "create_count_video(imgs,\n",
        "                   per_frame_counts,\n",
        "                   within_period,\n",
        "                   score=pred_score,\n",
        "                   fps=vid_fps,\n",
        "                   output_file='/tmp/debug_video.mp4',\n",
        "                   delay=1000/VIZ_FPS,\n",
        "                   plot_count=True,\n",
        "                   plot_within_period=True,\n",
        "                   plot_score=True)\n",
        "show_video('/tmp/debug_video.mp4')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jGFXFGjzzt5U"
      },
      "source": [
        "# Citation\n",
        "\n",
        "If you found our paper/code useful in your research, consider citing our paper:\n",
        "\n",
        "\n",
        "```\n",
        "@InProceedings{Dwibedi_2020_CVPR,\n",
        "author = {Dwibedi, Debidatta and Aytar, Yusuf and Tompson, Jonathan and Sermanet, Pierre and Zisserman, Andrew},\n",
        "title = {Counting Out Time: Class Agnostic Video Repetition Counting in the Wild},\n",
        "booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n",
        "month = {June},\n",
        "year = {2020}\n",
        "}\n",
        "```\n",
        "\n",
        "\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "last_runtime": {},
      "name": "repnet_colab.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
